/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef EXAMPLES_MATRIX_MATMUL_MNDB_CUSTOM_IMPL_H
#define EXAMPLES_MATRIX_MATMUL_MNDB_CUSTOM_IMPL_H
#include "kernel_operator.h"
#define ASCENDC_CUBE_ONLY
#include "lib/matmul_intf.h"

namespace CustomMatmulMndb {
// Norm-template matmul configs for MNDB (M/N double-buffer): both use
// ScheduleType::OUTER_PRODUCT; they differ only in the iterate order
// (ORDER_M vs ORDER_N). The trailing two booleans enable the MNDB options
// of MatmulFuncParams.
constexpr static MatmulConfigMode configMode = MatmulConfigMode::CONFIG_NORM;
constexpr static MatmulFuncParams mFuncParams{false, false, false, false, 0, IterateOrder::ORDER_M, ScheduleType::OUTER_PRODUCT, true, true};
constexpr static MatmulFuncParams nFuncParams{false, false, false, false, 0, IterateOrder::ORDER_N, ScheduleType::OUTER_PRODUCT, true, true};
constexpr static MatmulConfig CFG_NORM_OUTER_PRODUCT_M = GetMMConfig<configMode>(mFuncParams);
constexpr static MatmulConfig CFG_NORM_OUTER_PRODUCT_N = GetMMConfig<configMode>(nFuncParams);

// MDL-template variants of the same two configs (CONFIG_MDL instead of
// CONFIG_NORM); selected by mndbMode 3 and 4 in Process().
constexpr static MatmulConfigMode configModeMDL = MatmulConfigMode::CONFIG_MDL;
constexpr static MatmulFuncParams funcParamsOrderM{false, false, false, false, 0, IterateOrder::ORDER_M, ScheduleType::OUTER_PRODUCT, true, true};
constexpr static MatmulFuncParams funcParamsOrderN{false, false, false, false, 0, IterateOrder::ORDER_N, ScheduleType::OUTER_PRODUCT, true, true};
constexpr static MatmulConfig CFG_MDL_OUTER_PRODUCT_ORDER_M = GetMMConfig<configModeMDL>(funcParamsOrderM);
constexpr static MatmulConfig CFG_MDL_OUTER_PRODUCT_ORDER_N = GetMMConfig<configModeMDL>(funcParamsOrderN);


// Single-core matmul kernel demonstrating the MNDB (outer-product schedule)
// configs above. mndbMode selects which config Process() instantiates:
//   1: norm template, ORDER_M   2: norm template, ORDER_N
//   3: MDL template,  ORDER_M   4: MDL template,  ORDER_N
// Any other value makes Process() a no-op.
template<typename aType, typename bType, typename cType, typename biasType, int32_t mndbMode>
class MatmulMndbKernel {
    public:
        __aicore__ inline MatmulMndbKernel(){};
        // Binds the GM buffers for A/B/C/bias and advances them to this
        // core's sub-block according to the tiling.
        __aicore__ inline void Init(GM_ADDR a, GM_ADDR b, GM_ADDR bias, GM_ADDR c, GM_ADDR workspace,
            const TCubeTiling& tiling);
        // Runs the full single-core matmul (C = A x B [+ bias]) with the
        // config selected by mndbMode.
        __aicore__ inline void Process(AscendC::TPipe* pipe);

    private:
        // Computes this core's element offsets into A/B/C/bias from its block
        // index and trims the tail-tile sizes in `param` in place.
        __aicore__ inline void CalcOffset(
            int32_t blockIdx, TCubeTiling& param, int32_t& offsetA, int32_t& offsetB,
            int32_t& offsetC, int32_t& offsetBias);

        AscendC::GlobalTensor<aType> aGlobal;       // A matrix, GM, ND layout
        AscendC::GlobalTensor<bType> bGlobal;       // B matrix, GM, ND layout
        AscendC::GlobalTensor<cType> cGlobal;       // C (output) matrix, GM
        AscendC::GlobalTensor<biasType> biasGlobal; // bias vector of length N
        TCubeTiling tiling;                         // per-core copy, trimmed by CalcOffset
};

// Binds the GM buffers for A/B/C/bias and advances them to this core's
// sub-block. The offsets are derived from the tiling and the block index;
// CalcOffset also trims the single-core tile sizes for tail blocks.
template<typename aType, typename bType, typename cType, typename biasType, int32_t mndbMode>
__aicore__ inline void MatmulMndbKernel<aType, bType, cType, biasType, mndbMode>::Init(GM_ADDR a,
        GM_ADDR b, GM_ADDR bias, GM_ADDR c, GM_ADDR workspace, const TCubeTiling& tiling)
{
    // Guard first: the matmul object registered in Process() needs the system
    // workspace, so skip all buffer setup when it is unavailable. (Previously
    // this check sat at the end of the function, where the early return had
    // no effect.)
    if (GetSysWorkSpacePtr() == nullptr) {
        return;
    }
    this->tiling = tiling;
    // Bind the full matrices: A is M x Ka, B is Kb x N, C is M x N, bias is N.
    aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ aType*>(a), tiling.M * tiling.Ka);
    bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ bType*>(b), tiling.Kb * tiling.N);
    cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ cType*>(c), tiling.M * tiling.N);
    biasGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ biasType*>(bias), tiling.N);

    int32_t offsetA = 0;
    int32_t offsetB = 0;
    int32_t offsetC = 0;
    int32_t offsetBias = 0;
    // Note: CalcOffset mutates this->tiling (tail-tile trimming), so it must
    // run after the copy above and before Process() reads the tiling.
    CalcOffset(AscendC::GetBlockIdx(), this->tiling, offsetA, offsetB, offsetC, offsetBias);
    aGlobal = aGlobal[offsetA];
    bGlobal = bGlobal[offsetB];
    cGlobal = cGlobal[offsetC];
    biasGlobal = biasGlobal[offsetBias];
}

template<typename aType, typename bType, typename cType, typename biasType, int32_t mndbMode>
__aicore__ inline void MatmulMndbKernel<aType, bType, cType, biasType, mndbMode>::Process(AscendC::TPipe* pipe)
{

    if constexpr (mndbMode == 1) {
        AscendC::Matmul<AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, aType>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, bType>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, cType>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, biasType>, CFG_NORM_OUTER_PRODUCT_M> matmulObj;
        REGIST_MATMUL_OBJ(pipe, GetSysWorkSpacePtr(), matmulObj, &(this->tiling));
        matmulObj.SetTensorA(aGlobal, false);
        matmulObj.SetTensorB(bGlobal, false);
        if (tiling.isBias) {
            matmulObj.SetBias(biasGlobal);
        }
        matmulObj.IterateAll(cGlobal);
        matmulObj.End();
    } else if constexpr (mndbMode == 2) {
        AscendC::Matmul<AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, aType>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, bType>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, cType>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, biasType>, CFG_NORM_OUTER_PRODUCT_N> matmulObj;
        REGIST_MATMUL_OBJ(pipe, GetSysWorkSpacePtr(), matmulObj, &(this->tiling));
        matmulObj.SetTensorA(aGlobal, false);
        matmulObj.SetTensorB(bGlobal, false);
        if (tiling.isBias) {
            matmulObj.SetBias(biasGlobal);
        }
        matmulObj.IterateAll(cGlobal);
        matmulObj.End();
    } else if constexpr (mndbMode == 3) {
        AscendC::Matmul<AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, aType>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, bType>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, cType>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, biasType>, CFG_MDL_OUTER_PRODUCT_ORDER_M> matmulObj;
        REGIST_MATMUL_OBJ(pipe, GetSysWorkSpacePtr(), matmulObj, &(this->tiling));
        matmulObj.SetTensorA(aGlobal, false);
        matmulObj.SetTensorB(bGlobal, false);
        if (tiling.isBias) {
            matmulObj.SetBias(biasGlobal);
        }
        matmulObj.IterateAll(cGlobal);
        matmulObj.End();
    } else if constexpr (mndbMode == 4) {
        AscendC::Matmul<AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, aType>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, bType>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, cType>,
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, biasType>, CFG_MDL_OUTER_PRODUCT_ORDER_N> matmulObj;
        REGIST_MATMUL_OBJ(pipe, GetSysWorkSpacePtr(), matmulObj, &(this->tiling));
        matmulObj.SetTensorA(aGlobal, false);
        matmulObj.SetTensorB(bGlobal, false);
        if (tiling.isBias) {
            matmulObj.SetBias(biasGlobal);
        }
        matmulObj.IterateAll(cGlobal);
        matmulObj.End();
    }
}


// Computes this core's element offsets into the row-major ND matrices from
// its block index, then trims the single-core tile sizes in `param` so that
// tail cores do not read or write past the matrix boundaries.
// Layout of cores: each group of `divideKCoreNum` consecutive cores covers
// the M x N tile grid (M-major); successive groups handle successive K slices.
template<typename aType, typename bType, typename cType, typename biasType, int32_t mndbMode>
__aicore__ inline void MatmulMndbKernel<aType, bType, cType, biasType, mndbMode>::CalcOffset(
    int32_t blockIdx, TCubeTiling& param,
    int32_t& offsetA, int32_t& offsetB, int32_t& offsetC, int32_t& offsetBias)
{
    // Number of single-core tiles along M and K. (The original also computed
    // the N tile count, but it was never used.)
    auto mTileCnt = AscendC::Ceil(param.M, param.singleCoreM);
    auto kTileCnt = AscendC::Ceil(param.Ka, param.singleCoreK);

    // Cores per K slice.
    auto divideKCoreNum = param.usedCoreNum / kTileCnt;

    auto mCoreIndex = (blockIdx % divideKCoreNum) % mTileCnt;
    auto nCoreIndex = (blockIdx % divideKCoreNum) / mTileCnt;
    auto subKIndex = blockIdx / divideKCoreNum;

    // Element offsets: A is M x Ka, B is Ka x N, C is M x N (all row-major).
    offsetA = mCoreIndex * param.Ka * param.singleCoreM + subKIndex * param.singleCoreK;
    offsetB = subKIndex * param.singleCoreK * param.N + nCoreIndex * param.singleCoreN;
    offsetC = mCoreIndex * param.N * param.singleCoreM + nCoreIndex * param.singleCoreN;
    offsetBias = nCoreIndex * param.singleCoreN;

    // Clamp tail tiles to the remaining extent in each dimension.
    int32_t gmUseM = param.M - mCoreIndex * param.singleCoreM;
    param.singleCoreM = gmUseM < param.singleCoreM ? gmUseM : param.singleCoreM;

    int32_t gmUseN = param.N - nCoreIndex * param.singleCoreN;
    param.singleCoreN = gmUseN < param.singleCoreN ? gmUseN : param.singleCoreN;

    int32_t gmUseK = param.Ka - subKIndex * param.singleCoreK;
    param.singleCoreK = gmUseK < param.singleCoreK ? gmUseK : param.singleCoreK;
}
}  // namespace CustomMatmulMndb
#endif // EXAMPLES_MATRIX_MATMUL_MNDB_CUSTOM_IMPL_H